{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Introduction to Encrypted Tensors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "import torch\n",
    "\n",
    "# This notebook requires exactly Python 3.7; stop early on any other interpreter.\n",
    "assert sys.version_info[:2] == (3, 7), \"python 3.7 is required\"\n",
    "\n",
    "import crypten\n",
    "\n",
    "# Initialize CrypTen before any CrypTensor is created in the cells that follow.\n",
    "crypten.init()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Creating a CrypTensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a CrypTensor from a plain Python list of floats.\n",
    "x = crypten.cryptensor([1.0, 2.0, 3.0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MPCTensor(\n",
       "\t_tensor=tensor([ 65536, 131072, 196608])\n",
       "\tplain_text=HIDDEN\n",
       "\tptype=ptype.arithmetic\n",
       ")"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Echo the CrypTensor: its repr shows the encoded _tensor while plain_text stays HIDDEN.\n",
    "# In the recorded output 1.0 appears as 65536, i.e. the values are scaled by 2**16.\n",
    "x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([1., 2., 3.])"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# get_plain_text() recovers the original values from the CrypTensor.\n",
    "x.get_plain_text()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A CrypTensor can also be created from an existing torch tensor.\n",
    "y = crypten.cryptensor(torch.tensor([4.0, 5.0, 6.0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* Note: this notebook runs in a single-party setting to demonstrate the API. In a single-party setting, tensors are encoded but not encrypted, because there are no other parties with which to exchange secret shares."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Operations on CrypTensors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MPCTensor(\n",
       "\t_tensor=tensor([196608, 262144, 327680])\n",
       "\tplain_text=HIDDEN\n",
       "\tptype=ptype.arithmetic\n",
       ")"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Add a plain Python float to every element; the result is still a CrypTensor.\n",
    "x + 2.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([3., 4., 5.])"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Decode the sum to check it against ordinary arithmetic: [1, 2, 3] + 2.\n",
    "(x + 2.0).get_plain_text()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MPCTensor(\n",
       "\t_tensor=tensor([327680, 458752, 589824])\n",
       "\tplain_text=HIDDEN\n",
       "\tptype=ptype.arithmetic\n",
       ")"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Element-wise addition of two CrypTensors.\n",
    "x + y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MPCTensor(\n",
       "\t_tensor=2097152\n",
       "\tplain_text=HIDDEN\n",
       "\tptype=ptype.arithmetic\n",
       ")"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Dot product of two CrypTensors; the result is a scalar CrypTensor.\n",
    "# The recorded 2097152 equals 32 * 65536, matching 4*1 + 5*2 + 6*3 = 32.\n",
    "y.dot(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For a full list of the supported operations, see the [docs](https://crypten.readthedocs.io/en/latest/)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example: Compute Mean Squared Loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x MPCTensor(\n",
      "\t_tensor=tensor([ 65536, 131072, 196608])\n",
      "\tplain_text=HIDDEN\n",
      "\tptype=ptype.arithmetic\n",
      ")\n",
      "y MPCTensor(\n",
      "\t_tensor=tensor([262144, 327680, 393216])\n",
      "\tplain_text=HIDDEN\n",
      "\tptype=ptype.arithmetic\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Recap the two CrypTensors that serve as inputs to the loss computation.\n",
    "print(\"x\", x)\n",
    "print(\"y\", y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(9.)\n"
     ]
    }
   ],
   "source": [
    "# Per-element squared error between the two CrypTensors.\n",
    "squared_loss = (x - y)**2\n",
    "# Average the per-element errors into a single value.\n",
    "mean_squared_loss = squared_loss.mean()\n",
    "\n",
    "# Only the final result is decoded for display.\n",
    "print(mean_squared_loss.get_plain_text())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### PyTorch Version"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plain torch tensors holding the same values used for the CrypTensors x and y.\n",
    "x_pytorch = torch.tensor([1.0, 2.0, 3.0])\n",
    "y_pytorch = torch.tensor([4.0, 5.0, 6.0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(9.)\n"
     ]
    }
   ],
   "source": [
    "# Same mean squared loss on plain tensors, for comparison with the CrypTen result.\n",
    "squared_loss_pytorch = (x_pytorch - y_pytorch)**2\n",
    "print(squared_loss_pytorch.mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "crypten-pycon",
   "language": "python",
   "name": "crypten-pycon"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
